{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "46f81679",
   "metadata": {},
   "source": [
    "# 사용자 정의 제네레이터(generator)\n",
    "\n",
    "제너레이터 함수(즉, `yield` 키워드를 사용하고 이터레이터처럼 동작하는 함수)를 LCEL 파이프라인에서 사용할 수 있습니다.\n",
    "\n",
    "이러한 제너레이터의 시그니처는 `Iterator[Input] -> Iterator[Output]`이어야 합니다. 비동기 제너레이터의 경우에는 `AsyncIterator[Input] -> AsyncIterator[Output]`입니다.\n",
    "\n",
    "이는 다음과 같은 용도로 유용합니다.\n",
    "\n",
    "- 사용자 정의 출력 파서 구현\n",
    "- 이전 단계의 출력을 수정하면서 스트리밍 기능 유지\n",
    "\n",
    "아래의 예제에서는 쉼표로 구분된 목록에 대한 사용자 정의 출력 파서를 구현해 보겠습니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "94adfad2",
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -qU langchain langchain-openai"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 31,
   "id": "711f54a9",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Iterator, List\n",
    "\n",
    "from langchain.prompts.chat import ChatPromptTemplate\n",
    "from langchain_core.output_parsers import StrOutputParser\n",
    "from langchain_openai import ChatOpenAI\n",
    "\n",
    "prompt = ChatPromptTemplate.from_template(\n",
    "    # 주어진 회사와 유사한 5개의 회사를 쉼표로 구분된 목록으로 작성하세요.\n",
    "    \"Write a comma-separated list of 5 companies similar to: {company}\"\n",
    ")\n",
    "# 온도를 0.0으로 설정하여 ChatOpenAI 모델을 초기화합니다.\n",
    "model = ChatOpenAI(temperature=0.0, model=\"gpt-4-turbo-preview\")\n",
    "\n",
    "# 프롬프트와 모델을 연결하고 문자열 출력 파서를 적용하여 체인을 생성합니다.\n",
    "str_chain = prompt | model | StrOutputParser()"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4787d9d5",
   "metadata": {},
   "source": [
    "아래의 체인을 **스트림(stream)** 으로 실행하고 결과를 출력합니다. 결과는 실시간으로 생성됩니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 32,
   "id": "d8ea3093",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Microsoft, Amazon, Facebook (Meta), Apple, Alibaba"
     ]
    }
   ],
   "source": [
    "# 데이터를 스트리밍합니다.\n",
    "for chunk in str_chain.stream({\"company\": \"Google\"}):\n",
    "    # 각 청크를 출력하고, 줄 바꿈 없이 버퍼를 즉시 플러시합니다.\n",
    "    print(chunk, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "1ae77647",
   "metadata": {},
   "source": [
    "`invoke()` 하여 실행 결과를 확인합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 33,
   "id": "a80c60be",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'Microsoft, Amazon, Facebook (Meta), Apple, Alibaba'"
      ]
     },
     "execution_count": 33,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 체인에 데이터를 invoke 합니다.\n",
    "str_chain.invoke({\"company\": \"Google\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "396681fe",
   "metadata": {},
   "source": [
    "`split_into_list` 함수는 LLM 토큰의 이터레이터를 입력으로 받아 쉼표로 구분된 문자열 리스트의 이터레이터를 반환합니다.\n",
    "\n",
    "- 마지막 청크를 `yield`하여 반환합니다(generator).\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 34,
   "id": "bcb147d8",
   "metadata": {},
   "outputs": [],
   "source": [
    "# 입력으로 llm 토큰의 반복자를 받아 쉼표로 구분된 문자열 리스트로 분할하는 사용자 정의 파서입니다.\n",
    "def split_into_list(input: Iterator[str]) -> Iterator[List[str]]:\n",
    "    # 쉼표가 나올 때까지 부분 입력을 보관합니다.\n",
    "    buffer = \"\"\n",
    "    for chunk in input:\n",
    "        # 현재 청크를 버퍼에 추가합니다.\n",
    "        buffer += chunk\n",
    "        # 버퍼에 쉼표가 있는 동안 반복합니다.\n",
    "        while \",\" in buffer:\n",
    "            # 버퍼를 쉼표로 분할합니다.\n",
    "            comma_index = buffer.index(\",\")\n",
    "            # 쉼표 이전의 모든 내용을 반환합니다.\n",
    "            yield [buffer[:comma_index].strip()]\n",
    "            # 나머지는 다음 반복을 위해 저장합니다.\n",
    "            buffer = buffer[comma_index + 1 :]\n",
    "    # 마지막 청크를 반환합니다.\n",
    "    yield [buffer.strip()]"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "81604147",
   "metadata": {},
   "source": [
    "- `str_chain` 변수의 문자열을 파이프(`|`) 연산자를 사용하여 `split_into_list` 함수에 전달합니다.\n",
    "- `split_into_list` 함수는 문자열을 리스트로 분할하는 역할을 합니다.\n",
    "- 분할된 리스트는 `list_chain` 변수에 할당됩니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 55,
   "id": "6f7103e9",
   "metadata": {},
   "outputs": [],
   "source": [
    "list_chain = str_chain | split_into_list  # 문자열 체인을 리스트로 분할합니다."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "af461c5e",
   "metadata": {},
   "source": [
    "- `list_chain` 객체의 `stream` 메서드를 호출하여 입력 데이터 `{\"animal\": \"bear\"}`에 대한 출력을 생성합니다.\n",
    "- `stream` 메서드는 출력을 청크(chunk) 단위로 반환하며, 각 청크는 반복문을 통해 처리됩니다.\n",
    "- 각 청크는 `print` 함수를 사용하여 즉시 출력되며, `flush=True` 옵션을 통해 버퍼링 없이 출력이 즉시 이루어집니다.\n",
    "\n",
    "이 코드는 `list_chain`을 사용하여 입력 데이터에 대한 출력을 생성하고, 출력을 청크 단위로 처리하여 즉시 출력하는 과정을 보여줍니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 39,
   "id": "cdd7adca",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Microsoft']\n",
      "['Amazon']\n",
      "['Apple']\n",
      "['Facebook (Meta)']\n",
      "['Alibaba']\n"
     ]
    }
   ],
   "source": [
    "# 생성한 list_chain 이 문제없이 스트리밍되는지 확인합니다.\n",
    "for chunk in list_chain.stream({\"company\": \"Google\"}):\n",
    "    print(chunk, flush=True)  # 각 청크를 출력하고, 버퍼를 즉시 플러시합니다."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "978790f2",
   "metadata": {},
   "source": [
    "이번에는 `list_chain` 에 `invoke()` 로 데이터를 주입합니다. 문제없이 동작하는지 확인해 보겠습니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 41,
   "id": "2756743e",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Microsoft', 'Amazon', 'Facebook (Meta)', 'Apple', 'Alibaba']"
      ]
     },
     "execution_count": 41,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# list_chain 에 데이터를 invoke 합니다.\n",
    "list_chain.invoke({\"company\": \"Google\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "4a41a78b",
   "metadata": {},
   "source": [
    "## 비동기(Asynchronous)\n"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "eca11b62",
   "metadata": {},
   "source": [
    "`asplit_into_list` 함수는 비동기 제너레이터(`AsyncIterator[str]`)를 입력으로 받아 문자열 리스트의 비동기 제너레이터(`AsyncIterator[List[str]]`)를 반환합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 43,
   "id": "c9fc2631",
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import AsyncIterator\n",
    "\n",
    "\n",
    "# 비동기 함수 정의\n",
    "async def asplit_into_list(input: AsyncIterator[str]) -> AsyncIterator[List[str]]:\n",
    "    buffer = \"\"\n",
    "    # `input`은 `async_generator` 객체이므로 `async for`를 사용\n",
    "    async for chunk in input:\n",
    "        buffer += chunk\n",
    "        while \",\" in buffer:\n",
    "            comma_index = buffer.index(\",\")\n",
    "            yield [\n",
    "                buffer[:comma_index].strip()\n",
    "            ]  # 쉼표를 기준으로 분할하여 리스트로 반환\n",
    "            buffer = buffer[comma_index + 1:]\n",
    "    yield [buffer.strip()]  # 남은 버퍼 내용을 리스트로 반환\n",
    "\n",
    "\n",
    "# alist_chain 과 asplit_into_list 를 파이프라인으로 연결\n",
    "alist_chain = str_chain | asplit_into_list"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "e37aef40",
   "metadata": {},
   "source": [
    "비동기 함수를 사용하여 스트리밍합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 75,
   "id": "e81ac57c",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "['Microsoft']\n",
      "['Amazon']\n",
      "['Facebook (Meta)']\n",
      "['Apple']\n",
      "['Alibaba']\n"
     ]
    }
   ],
   "source": [
    "# async for 루프를 사용하여 데이터를 스트리밍합니다.\n",
    "async for chunk in alist_chain.astream({\"company\": \"Google\"}):\n",
    "    # 각 청크를 출력하고 버퍼를 비웁니다.\n",
    "    print(chunk, flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "49f609b1",
   "metadata": {},
   "source": [
    "비동기 chain 에 데이터를 invoke 하여 문제없이 동작하는지 확인합니다.\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 79,
   "id": "e2788110",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "['Microsoft', 'Amazon', 'Facebook (Meta)', 'Apple', 'Alibaba']"
      ]
     },
     "execution_count": 79,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 리스트 체인을 비동기적으로 호출합니다.\n",
    "await alist_chain.ainvoke({\"company\": \"Google\"})"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "py-test",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
